from typing import Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig


class PromptTransform:
    """Shared helpers for fitting conversation history into a prompt."""

    def _append_chat_histories(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> list[PromptMessage]:
        """Extend ``prompt_messages`` with history fetched from ``memory``.

        The token budget handed to the memory is whatever
        ``_calculate_rest_token`` reports for the current prompt.

        :param memory: source of the stored conversation history
        :param memory_config: window settings that cap the number of messages
        :param prompt_messages: messages built so far; extended **in place**
        :param model_config: model whose limits bound the token budget
        :return: the same ``prompt_messages`` list object, after extension
        """
        token_budget = self._calculate_rest_token(prompt_messages, model_config)
        prompt_messages.extend(
            self._get_history_messages_list_from_memory(memory, memory_config, token_budget)
        )
        return prompt_messages

    def _calculate_rest_token(
        self,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> int:
        """Return how many tokens remain available beyond the current prompt.

        Computed as ``context size - configured max_tokens - prompt tokens``
        and never negative. When the model schema declares no (or a falsy)
        context size, a fixed budget of 2000 is returned instead.

        :param prompt_messages: messages whose token count is subtracted
        :param model_config: supplies schema, parameters, credentials and the
            model type instance used for counting tokens
        :return: non-negative number of remaining tokens
        """
        schema = model_config.model_schema
        context_size = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if not context_size:
            # No usable context size: fall back to the fixed default budget.
            return 2000

        llm = cast(LargeLanguageModel, model_config.provider_model_bundle.model_type_instance)
        used_tokens = llm.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages,
        )

        # Tokens set aside for the completion. There is deliberately no
        # ``break``: if several rules match, the last one decides.
        reserved_tokens = 0
        for rule in schema.parameter_rules:
            if rule.name == 'max_tokens' or rule.use_template == 'max_tokens':
                reserved_tokens = (
                    model_config.parameters.get(rule.name)
                    or model_config.parameters.get(rule.use_template)
                    or 0
                )

        return max(context_size - reserved_tokens - used_tokens, 0)

    def _get_history_messages_from_memory(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        max_token_limit: int,
        human_prefix: Optional[str] = None,
        ai_prefix: Optional[str] = None,
    ) -> str:
        """Fetch the conversation history from ``memory`` as prompt text.

        Optional arguments are forwarded only when they carry a usable value,
        so the memory's own defaults apply otherwise.

        :param memory: source of the stored conversation history
        :param memory_config: window settings; ``window.size`` is forwarded as
            ``message_limit`` only when the window is enabled and the size is
            a positive number
        :param max_token_limit: token budget passed through to the memory
        :param human_prefix: forwarded only when truthy
        :param ai_prefix: forwarded only when truthy
        :return: whatever ``memory.get_history_prompt_text`` returns
        """
        call_args = {"max_token_limit": max_token_limit}
        prefixes = (("human_prefix", human_prefix), ("ai_prefix", ai_prefix))
        call_args.update((key, value) for key, value in prefixes if value)

        window = memory_config.window
        if window.enabled and window.size is not None and window.size > 0:
            call_args["message_limit"] = window.size

        return memory.get_history_prompt_text(**call_args)

    def _get_history_messages_list_from_memory(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        max_token_limit: int,
    ) -> list[PromptMessage]:
        """Fetch the conversation history from ``memory`` as prompt messages.

        :param memory: source of the stored conversation history
        :param memory_config: window settings; ``window.size`` is used as the
            message limit when the window is enabled and the size is a
            positive number, otherwise the limit is 10
        :param max_token_limit: token budget passed through to the memory
        :return: whatever ``memory.get_history_prompt_messages`` returns
        """
        window = memory_config.window
        if window.enabled and window.size is not None and window.size > 0:
            message_limit = window.size
        else:
            message_limit = 10

        return memory.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=message_limit,
        )
